#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/1/7 16:49
# @Author  : zzp
# @File    : visualization
# @Software: PyCharm
import torch
from matplotlib import image as mpimg, pyplot as plt


def visualize(path, target, dataset, ans, logger):
    """
    Display a single image and log its labels and the subscribed content names.

    The image at ``path`` is shown once with its axes hidden. Two messages are
    then sent to ``logger.info``: the list of class names selected by
    ``target``, and the space-joined ``content_name`` of every entry of
    ``dataset.clients`` referenced by ``ans``.

    @param path: location of the image file, handed to ``mpimg.imread``.
    @param target: tensor whose entries equal to 1 mark the labels to report.
        The index along its second dimension is used to look up
        ``dataset.classnames``, so it needs at least two dimensions
        (presumably ``(batch, num_classes)`` -- confirm against the caller).
    @param dataset: object exposing ``classnames`` and ``clients``.
    @param ans: iterable of keys/indices into ``dataset.clients``.
    @param logger: object with an ``info`` method.
    @return: None
    """
    # Second-dimension positions of every entry equal to 1.
    positive = torch.nonzero(target == 1, as_tuple=True)[1]
    # Drop duplicates (the same class can be hit in several rows) and sort so
    # the logged label order is stable instead of following set hash order.
    class_ids = sorted(set(positive.tolist()))

    # Read and show the picture a single time.
    img = mpimg.imread(path)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

    # Log the labels that belong to the image.
    labels = [dataset.classnames[i] for i in class_ids]
    logger.info(labels)

    # Log the content names of the clients listed in ans.
    content_names = [dataset.clients[node].content_name for node in ans]
    logger.info(" ".join(content_names))
